#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree.
from __future__ import print_function

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import csv
import os

import numpy as np
import torch
from torch.autograd import Variable, grad
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter

from utils import random_perturb, make_step, inf_data_gen, Logger
from utils import soft_cross_entropy, classwise_loss, FocalLoss, LDAMLoss, BalancedSoftmaxLoss, ClassBalancedSoftmax
from Sinkhorn_distance import SinkhornDistance_one_to_multi
from config import *

# Cosine similarity over the last axis; used by avg_sample_dis().
fea_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8)

# Logging: text logger, TensorBoard writer and a per-seed CSV of epoch metrics.
LOGNAME = 'Imbalance_' + LOGFILE_BASE
logger = Logger(LOGNAME)
LOGDIR = logger.logdir

writer = SummaryWriter(LOGDIR)
LOG_CSV = os.path.join(LOGDIR, f'log_{SEED}.csv')
LOG_CSV_HEADER = [
    'epoch', 'train loss', 'gen loss', 'train acc', 'gen_acc', 'prob_orig', 'prob_targ',
    'test loss', 'major test acc', 'neutral test acc', 'minor test acc', 'test acc', 'f1 score'
]
# Write the header only once; the training loop appends one row per epoch.
if not os.path.exists(LOG_CSV):
    # newline='' is what the csv module expects for file objects (otherwise
    # rows get an extra blank line on platforms that translate '\n').
    with open(LOG_CSV, 'w', newline='') as f:
        csv_writer = csv.writer(f, delimiter=',')
        csv_writer.writerow(LOG_CSV_HEADER)

def save_checkpoint(acc, model, optim, epoch, index=False):
    """Write a checkpoint of ``model`` and ``optim`` into ``LOGDIR``.

    The saved dict holds the model and optimizer state dicts, ``acc``,
    ``epoch`` and ``torch.get_rng_state()``.  With ``index=True`` the file
    name also contains the epoch, so successive calls keep separate files;
    otherwise a single per-seed file is overwritten each time.
    """
    print('Saving..')

    # Store the wrapped network's parameters rather than the DataParallel wrapper.
    net = model.module if isinstance(model, nn.DataParallel) else model

    state = dict(
        net=net.state_dict(),
        optimizer=optim.state_dict(),
        acc=acc,
        epoch=epoch,
        rng_state=torch.get_rng_state(),
    )

    ckpt_name = f'ckpt_epoch{epoch!s}_{SEED!s}.t7' if index else f'ckpt_{SEED!s}.t7'
    torch.save(state, os.path.join(LOGDIR, ckpt_name))


def train_epoch(model, criterion, optimizer, data_loader, logger=None):
    """Run one epoch of plain supervised training.

    Reads the module-level ``epoch``, ``ARGS``, ``smote_loader_inf``,
    ``device``, ``normalizer`` and ``sum_t``.  After warm-up with SMOTE
    enabled, each batch is taken from ``smote_loader_inf`` instead of
    ``data_loader`` (which then only determines the number of steps).

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.train()

    loss_sum, n_correct, n_seen = 0, 0, 0

    for inputs, targets in tqdm(data_loader):
        if epoch >= ARGS.warm and ARGS.smote:
            inputs, targets = next(smote_loader_inf)

        inputs = inputs.to(device)
        targets = targets.to(device)
        n = inputs.size(0)

        logits, _ = model(normalizer(inputs))
        loss = criterion(logits, targets).mean()

        # Weight the batch-mean loss by batch size so the epoch mean is per-sample.
        loss_sum += loss.item() * n
        _, predicted = logits.max(1)
        n_seen += n
        n_correct += sum_t(predicted.eq(targets))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = loss_sum / n_seen
    accuracy = 100. * n_correct / n_seen
    msg = 'Loss: %.3f| Acc: %.3f%% (%d/%d)' % (avg_loss, accuracy, n_correct, n_seen)
    (logger.log if logger else print)(msg)

    return avg_loss, accuracy


def uniform_loss(outputs):
    """Soft cross-entropy of ``outputs`` against constant weights 1 / N_CLASSES."""
    uniform_targets = torch.ones_like(outputs).div(N_CLASSES)
    return soft_cross_entropy(outputs, uniform_targets, reduction='mean')

def classwise_loss(outputs, targets):
    """Mean over the batch of each row's output at its ``targets`` column.

    Note: this module-level definition replaces the ``classwise_loss``
    imported from ``utils`` at the top of the file.
    """
    mask = F.one_hot(targets.view(-1), num_classes=outputs.size(1)).to(outputs.dtype)
    per_sample = (outputs * mask).sum(1)
    return per_sample.mean()

def multi_classwise_loss(outputs, targets, confusion_targets):
    """Sum of two batch-mean signed row sums.

    For each index tensor, every row of ``outputs`` is summed with the entry
    at that row's index negated; the two batch means (for ``targets`` and
    for ``confusion_targets``) are added.
    """
    def _mean_signed_sum(index):
        sign = torch.ones_like(outputs)
        sign.scatter_(1, index.view(-1, 1), -1)
        return (outputs * sign).sum(1).mean()

    return _mean_signed_sum(targets) + _mean_signed_sum(confusion_targets)

def multi_classwise_loss_2(outputs, targets, confusion_targets):
    """Batch mean of ``outputs ** 3`` restricted to the target and confusion columns.

    For each row, only the entries at ``targets[i]`` and
    ``confusion_targets[i]`` contribute (once, even if both indices are the
    same).  The previous version OR-ed two tensors filled with +1/-1, which
    are non-zero everywhere, so the "mask" selected every column.
    Earlier experiments replaced the cube with sigmoid, GELU or ReLU.
    """
    mask = torch.zeros_like(outputs)
    mask.scatter_(1, targets.view(-1, 1), 1)
    mask.scatter_(1, confusion_targets.view(-1, 1), 1)

    binary_outputs = outputs ** 3 * mask

    return binary_outputs.sum(1).mean()

def ED_loss(target_data, synth_data):
    """Energy-distance-style loss built from mean pairwise squared distances.

    Returns ``2 * E[d(t, s)] - E[d(s, s')] - E[d(t, t')]`` where ``d`` is the
    squared Euclidean distance over the last axis and each expectation is a
    mean over all pairs produced by broadcasting.
    """
    def _pairwise_sq(a, b):
        # Broadcast a against b and sum squared differences over the feature axis.
        return torch.sum((a.unsqueeze(-2) - b.unsqueeze(-3)) ** 2, -1)

    cross = torch.mean(_pairwise_sq(target_data, synth_data))
    synth_self = torch.mean(_pairwise_sq(synth_data, synth_data))
    target_self = torch.mean(_pairwise_sq(target_data, target_data))

    return 2 * cross - synth_self - target_self


def MMD_loss(target_data, synth_data):
    """Squared Euclidean distance between the dim-0 means of the two inputs."""
    mean_gap = target_data.mean(dim=0) - synth_data.mean(dim=0)
    return mean_gap.pow(2).sum()

def random_sample_ori_data(ori_data_list):
    """Draw ``N_SAMPLES_PER_CLASS[-1]`` random samples from every class.

    Args:
        ori_data_list: one sequence of images per class (``group_data()[0]``).

    Returns:
        A float32 tensor on ``device`` with one leading entry per class,
        then the drawn samples, then the image axes.
    """
    n_draw = N_SAMPLES_PER_CLASS[-1]
    sampled_data = []

    for od in ori_data_list:
        # Shuffle indices rather than np.random.permutation(od), which copied
        # and shuffled the entire class just to keep n_draw images.
        # NOTE(review): with NumPy's legacy global RNG this should select the
        # same samples for a given seed; verify if bit-exact reproduction of
        # older runs matters.
        idx = np.random.permutation(len(od))[:n_draw]
        sampled_data.append(np.stack([od[i] for i in idx]))

    sampled_data = np.asarray(sampled_data, dtype=np.float64)
    sampled_data = torch.from_numpy(sampled_data).to(torch.float).to(device)

    return sampled_data


def avg_sample_dis(selected_features):
    """Pairwise cosine distances (1 - cosine similarity) summed over axis 1, then averaged over axis 1."""
    similarity = fea_cosine(selected_features.unsqueeze(-2), selected_features.unsqueeze(-3))
    distances = 1 - similarity
    return distances.sum(1).mean(1)


def generation(model_g, model_r, inputs, seed_targets, targets, p_accept,
               gamma, lam, step_size, ori_data, random_start=True, max_iter=10):
    """Translate seed images toward ``targets`` and decide which ones to accept.

    Each of ``max_iter`` iterations evaluates an objective on the current
    images ``x`` and updates ``x <- clamp(x - make_step(grad, 'l2', step_size), 0, 1)``.
    The objective is::

        CE(model_g(x), targets)
        + lam * classwise_loss(model_r(x), seed_targets)
        + <feature term>
        + lam * classwise_loss(model_r(x), confusion class of the seed)

    The feature term compares features of a freshly constructed ``MODEL``
    network on ``x`` with its features on real images of the target class
    drawn from ``ori_data``:

    * ``ARGS.name == 'ED'`` -> ``0.5 * ED_loss``
    * ``ARGS.name == 'OT'`` -> ``lam * OT_loss(...)[0].sum()``
    * any other name        -> ``0.5 * MMD_loss``

    The earlier test ``ARGS.name == 'M2m' or 'MMD'`` was always true, so every
    name other than 'ED'/'OT' already used the MMD term; the explicit chain
    keeps that behaviour without computing the MMD term only to discard it.

    Args:
        model_g: model whose prediction on ``x`` is pushed toward ``targets``
            (put in eval mode and left there).
        model_r: model being trained; switched to eval here and back to
            train mode before returning.
        inputs: seed images; kept in [0, 1] by clamping.
        seed_targets: labels of the seed images.
        targets: class each seed should be translated into.
        p_accept: per-sample probability for the Bernoulli acceptance draw.
        gamma: minimum softmax probability of ``model_g`` on ``targets``.
        lam: weight of the ``model_r`` terms (and of the OT cost).
        step_size: step size passed to ``make_step``.
        ori_data: ``[grouped_images, channel_stats]`` as built by ``group_data``.
        random_start: first add ``random_perturb(inputs, 'l2', 0.5)``.
        max_iter: number of update steps.

    Returns:
        ``(inputs, correct)``: the detached generated images and a bool mask
        of samples whose ``model_g`` probability on ``targets`` is >= ``gamma``
        and whose Bernoulli(``p_accept``) draw succeeded.
    """
    model_g.eval()
    model_r.eval()
    criterion = nn.CrossEntropyLoss()

    # A new randomly initialised network is built on every call and used only
    # as a feature extractor.  NOTE(review): it is never put in eval mode, so
    # any BatchNorm layers in MODEL would use batch statistics.
    random_net = models.__dict__[MODEL](N_CLASSES).to(device)

    # ori_data[1] (per-channel statistics from group_data) is not used here.
    sampled_ori_data = random_sample_ori_data(ori_data[0])

    # group_data treats the last image axis as channels (N, H, W, C), so move
    # channels first explicitly; the former reshape(-1, 3, 32, 32) reinterpreted
    # the memory without transposing and scrambled the pixels.
    sampled_ori_data = sampled_ori_data.flatten(0, 1).permute(0, 3, 1, 2).contiguous()
    _, features = random_net(normalizer(sampled_ori_data))
    features = features[-1].reshape(N_CLASSES, -1, 64)

    confusion_targets = argmax_confusion_target[seed_targets]

    if random_start:
        random_noise = random_perturb(inputs, 'l2', 0.5)
        inputs = torch.clamp(inputs + random_noise, 0, 1)

    # Real-image features for each requested target class.
    selected_features = features.index_select(0, targets)

    for _ in range(max_iter):
        inputs = inputs.clone().detach().requires_grad_(True)
        outputs_g, _ = model_g(normalizer(inputs))
        outputs_r, _ = model_r(normalizer(inputs))

        _, random_features = random_net(normalizer(inputs))
        random_features = random_features[-1]

        if ARGS.name == 'ED':
            feature_term = 0.5 * ED_loss(selected_features, random_features)
        elif ARGS.name == 'OT':
            cost, _, _ = OT_loss(random_features, selected_features)
            feature_term = lam * cost.sum()
        else:
            feature_term = 0.5 * MMD_loss(selected_features, random_features)

        loss = criterion(outputs_g, targets) + lam * classwise_loss(outputs_r, seed_targets) \
            + feature_term + lam * classwise_loss(outputs_r, confusion_targets)

        input_grad, = torch.autograd.grad(loss, [inputs])
        inputs = torch.clamp(inputs - make_step(input_grad, 'l2', step_size), 0, 1)

    inputs = inputs.detach()

    # Acceptance check only; no gradients are needed.
    with torch.no_grad():
        outputs_g, _ = model_g(normalizer(inputs))

    one_hot = torch.zeros_like(outputs_g)
    one_hot.scatter_(1, targets.view(-1, 1), 1)
    probs_g = torch.softmax(outputs_g, dim=1)[one_hot.to(torch.bool)]

    # Bool mask: indexing with uint8 masks is deprecated in PyTorch.
    correct = (probs_g >= gamma) & torch.bernoulli(p_accept).to(device).bool()

    model_r.train()

    return inputs, correct


def train_net(model_train, model_gen, criterion, optimizer_train, inputs_orig, targets_orig, gen_idx, gen_targets, ori_data):
    """Take one optimizer step on a batch in which some samples are replaced by generated ones.

    For every generation slot in ``gen_idx`` a seed sample is drawn from the
    batch with weight ``1 - ARGS.beta ** relu(n_seed - n_target)``, where the
    ``n_*`` are entries of ``N_SAMPLES_PER_CLASS_T``.  Slots whose weights
    are all zero are dropped (assuming 0 < ARGS.beta < 1, these are slots
    with no batch sample of a strictly larger class count).  Seeds are passed
    to ``generation``; accepted outputs overwrite the batch images and labels
    at their slots before the forward/backward pass.

    Args:
        model_train: network being optimized (``model_r`` in ``generation``).
        model_gen: network passed as ``model_g`` to ``generation``.
        criterion: loss that must return one value per sample (``loss`` is
            indexed below) -- presumably built with ``reduction='none'``.
        optimizer_train: optimizer stepped once here.
        inputs_orig, targets_orig: the batch; cloned, not modified in place.
        gen_idx: batch positions that may be replaced by generated samples.
        gen_targets: desired class for each entry of ``gen_idx``.
        ori_data: forwarded to ``generation``.

    Returns:
        ``(oth_loss_total, gen_loss_total, num_others, num_correct_oth,
        num_gen, num_correct_gen, p_g_orig, p_g_targ, success)``: batch totals
        (via ``sum_t``) for the untouched and the generated samples, and
        ``success``, an ``(N_CLASSES, 2)`` tensor whose column 0 counts
        accepted generations per target class and column 1 the attempted
        ones (after dropping invalid slots).
    """
    batch_size = inputs_orig.size(0)

    inputs = inputs_orig.clone()
    targets = targets_orig.clone()

    ########################

    # bs[i, j]: class count of batch sample j (same row repeated per slot i).
    bs = N_SAMPLES_PER_CLASS_T[targets_orig].repeat(gen_idx.size(0), 1)
    # gs[i]: class count of the target class of slot i.
    gs = N_SAMPLES_PER_CLASS_T[gen_targets].view(-1, 1)

    delta = F.relu(bs - gs)
    p_accept = 1 - ARGS.beta ** delta
    # Keep only slots with at least one eligible seed in the batch.
    mask_valid = (p_accept.sum(1) > 0)

    gen_idx = gen_idx[mask_valid]
    gen_targets = gen_targets[mask_valid]
    p_accept = p_accept[mask_valid]

    # One seed per slot, sampled in proportion to its row of p_accept; the
    # chosen seed's weight becomes the Bernoulli acceptance probability.
    select_idx = torch.multinomial(p_accept, 1, replacement=True).view(-1)
    p_accept = p_accept.gather(1, select_idx.view(-1, 1)).view(-1)

    seed_targets = targets_orig[select_idx]
    seed_images = inputs_orig[select_idx]

    gen_inputs, correct_mask = generation(model_gen, model_train, seed_images, seed_targets, gen_targets, p_accept,
                                          ARGS.gamma, ARGS.lam, ARGS.step_size, ori_data, True, ARGS.attack_iter)

    ########################

    # Only change the correctly generated samples
    num_gen = sum_t(correct_mask)
    num_others = batch_size - num_gen

    # Batch positions overwritten by accepted samples vs. left untouched.
    gen_c_idx = gen_idx[correct_mask]
    others_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
    others_mask[gen_c_idx] = 0
    others_idx = others_mask.nonzero().view(-1)

    if num_gen > 0:
        gen_inputs_c = gen_inputs[correct_mask]
        gen_targets_c = gen_targets[correct_mask]

        inputs[gen_c_idx] = gen_inputs_c
        targets[gen_c_idx] = gen_targets_c

    outputs, _ = model_train(normalizer(inputs))
    loss = criterion(outputs, targets)

    optimizer_train.zero_grad()
    loss.mean().backward()
    optimizer_train.step()

    # For logging the training

    oth_loss_total = sum_t(loss[others_idx])
    gen_loss_total = sum_t(loss[gen_c_idx])

    _, predicted = torch.max(outputs[others_idx].data, 1)
    num_correct_oth = sum_t(predicted.eq(targets[others_idx]))

    num_correct_gen, p_g_orig, p_g_targ = 0, 0, 0
    success = torch.zeros(N_CLASSES, 2)

    if num_gen > 0:
        _, predicted_gen = torch.max(outputs[gen_c_idx].data, 1)
        num_correct_gen = sum_t(predicted_gen.eq(targets[gen_c_idx]))
        probs = torch.softmax(outputs[gen_c_idx], 1).data

        # Trained model's probability of the seed (original) class on the generated images.
        p_g_orig = probs.gather(1, seed_targets[correct_mask].view(-1, 1))
        p_g_orig = sum_t(p_g_orig)

        # ... and of the target class.
        p_g_targ = probs.gather(1, gen_targets_c.view(-1, 1))
        p_g_targ = sum_t(p_g_targ)

    for i in range(N_CLASSES):
        if num_gen > 0:
            success[i, 0] = sum_t(gen_targets_c == i)
        success[i, 1] = sum_t(gen_targets == i)

    return oth_loss_total, gen_loss_total, num_others, num_correct_oth, num_gen, num_correct_gen, p_g_orig, p_g_targ, success


def train_gen_epoch(net_t, net_g, criterion, optimizer, data_loader, ori_data):
    """Train ``net_t`` for one epoch, replacing part of each batch with generated samples.

    Per batch, generation slots are chosen and handed to ``train_net``:

    * imbalanced data (``ARGS.imb_type != 'none'``): each sample is kept with
      probability ``n_c / n_0`` (entries of ``N_SAMPLES_PER_CLASS_T``,
      assumed to give a value in [0, 1]); the others become slots whose
      target is the sample's own label;
    * balanced data: every position is a slot with a random target class.

    Logs a summary through the module-level ``logger`` (or ``print``).

    Returns:
        dict with ``train_loss``, ``gen_loss``, ``train_acc``, ``gen_acc``,
        ``p_g_orig``, ``p_g_targ`` (averages over the epoch) and
        ``t_success`` (``(N_CLASSES, 2)`` accepted/attempted counts).
    """
    net_t.train()
    net_g.eval()

    oth_loss, gen_loss = 0, 0
    correct_oth = 0
    correct_gen = 0
    # Tiny epsilon so the averages below never divide by zero.
    total_oth, total_gen = 1e-6, 1e-6
    p_g_orig, p_g_targ = 0, 0
    t_success = torch.zeros(N_CLASSES, 2)

    for inputs, targets in tqdm(data_loader):
        batch_size = inputs.size(0)
        inputs, targets = inputs.to(device), targets.to(device)

        # Set a generation target for current batch with re-sampling
        if ARGS.imb_type != 'none':  # Imbalanced
            # Keep the sample with this probability
            gen_probs = N_SAMPLES_PER_CLASS_T[targets] / N_SAMPLES_PER_CLASS_T[0]
            gen_index = (1 - torch.bernoulli(gen_probs)).nonzero().view(-1)    # Generation index
            gen_targets = targets[gen_index]
        else:   # Balanced
            # Create the index on `device`: train_net indexes it with a mask
            # derived from N_SAMPLES_PER_CLASS_T (indexed by on-device targets
            # above), and a CPU tensor cannot be indexed by a GPU mask.
            gen_index = torch.arange(batch_size, device=device)
            gen_targets = torch.randint(N_CLASSES, (batch_size,)).to(device).long()

        t_loss, g_loss, num_others, num_correct, num_gen, num_gen_correct, p_g_orig_batch, p_g_targ_batch, success \
            = train_net(net_t, net_g, criterion, optimizer, inputs, targets, gen_index, gen_targets, ori_data)

        oth_loss += t_loss
        gen_loss += g_loss
        total_oth += num_others
        correct_oth += num_correct
        total_gen += num_gen
        correct_gen += num_gen_correct
        p_g_orig += p_g_orig_batch
        p_g_targ += p_g_targ_batch
        t_success += success

    res = {
        'train_loss': oth_loss / total_oth,
        'gen_loss': gen_loss / total_gen,
        'train_acc': 100. * correct_oth / total_oth,
        'gen_acc': 100. * correct_gen / total_gen,
        'p_g_orig': p_g_orig / total_gen,
        'p_g_targ': p_g_targ / total_gen,
        't_success': t_success
    }

    msg = 't_Loss: %.3f | g_Loss: %.3f | Acc: %.3f%% (%d/%d) | Acc_gen: %.3f%% (%d/%d) ' \
          '| Prob_orig: %.3f | Prob_targ: %.3f' % (
        res['train_loss'], res['gen_loss'],
        res['train_acc'], correct_oth, total_oth,
        res['gen_acc'], correct_gen, total_gen,
        res['p_g_orig'], res['p_g_targ']
    )
    if logger:
        logger.log(msg)
    else:
        print(msg)

    return res

def group_data(ori_dataset):
    """Split a dataset's images by label and compute per-channel statistics.

    Images are ``ori_dataset.data`` divided by 255; labels come from
    ``ori_dataset.targets``.  The last image axis is treated as the channel
    axis.

    Returns:
        ``[grouped_data, stat_info]`` where ``grouped_data[c]`` is the list of
        images with label ``c`` and ``stat_info[ch]`` is ``[mean, std]`` of
        channel ``ch`` over the whole dataset.
    """
    images = ori_dataset.data / 255
    labels = ori_dataset.targets

    grouped_data = [[] for _ in range(N_CLASSES)]
    for i, label in enumerate(labels):
        grouped_data[label].append(images[i])

    stat_info = [
        [images[:, :, :, ch].mean(), images[:, :, :, ch].std()]
        for ch in range(images.shape[-1])
    ]

    return [grouped_data, stat_info]
    

if __name__ == '__main__':
    TEST_ACC = 0  # best test accuracy
    BEST_VAL = 0  # best validation accuracy

    # Sinkhorn distance used as the global OT_loss by generation() when ARGS.name == 'OT'.
    sinkhorn_eps = 0.1
    OT_loss = SinkhornDistance_one_to_multi(eps=sinkhorn_eps, max_iter=200, reduction=None).to(device)

    # `net` is trained; `net_seed` is passed to train_gen_epoch() as net_g and
    # is not given to the optimizer.
    logger.log('==> Building model: %s' % MODEL)
    net = models.__dict__[MODEL](N_CLASSES)
    net_seed = models.__dict__[MODEL](N_CLASSES)

    net, net_seed = net.to(device), net_seed.to(device)
    optimizer = optim.SGD(net.parameters(), lr=ARGS.lr, momentum=0.9, weight_decay=ARGS.decay)

    if ARGS.resume:
        # Load checkpoint.
        logger.log('==> Resuming from checkpoint..')

        if ARGS.net_both is not None:
            # A single file holding both networks ('net', 'net2') and optimizer state.
            ckpt_t = torch.load(ARGS.net_both)
            net.load_state_dict(ckpt_t['net'])
            optimizer.load_state_dict(ckpt_t['optimizer'])
            START_EPOCH = ckpt_t['epoch'] + 1
            net_seed.load_state_dict(ckpt_t['net2'])
        else:
            if ARGS.net_t is not None:
                ckpt_t = torch.load(ARGS.net_t)
                net.load_state_dict(ckpt_t['net'])
                optimizer.load_state_dict(ckpt_t['optimizer'])
                START_EPOCH = ckpt_t['epoch'] + 1

            if ARGS.net_g is not None:
                ckpt_g = ARGS.net_g
                print(ckpt_g)
                ckpt_g = torch.load(ckpt_g)
                net_seed.load_state_dict(ckpt_g['net'])

    if N_GPUS > 1:
        logger.log('Multi-GPU mode: using %d GPUs for training.' % N_GPUS)
        net = nn.DataParallel(net)
        net_seed = nn.DataParallel(net_seed)
    elif N_GPUS == 1:
        logger.log('Single-GPU mode.')

    # NOTE(review): without a resume, START_EPOCH presumably comes from `config`.
    logger.log('==> Start Epoch: ' + str(START_EPOCH))
    SUCCESS = torch.zeros(EPOCH, N_CLASSES, 2)
    # Grouped original images for generation(); built once on first use, since
    # group_data only reads ori_dataset, which is loaded before the loop body uses it.
    ori_data = None

    def _convert_scala(x):
        """Turn tensors / NumPy scalars into plain Python numbers for the CSV."""
        if hasattr(x, 'item'):
            x = x.item()
        return x

    # Keys written to the CSV, in LOG_CSV_HEADER order (missing keys -> 0).
    log_tr = ['train_loss', 'gen_loss', 'train_acc', 'gen_acc', 'p_g_orig', 'p_g_targ']
    log_te = ['loss', 'major_acc', 'neutral_acc', 'minor_acc', 'acc', 'f1_score']

    for epoch in range(START_EPOCH, EPOCH):
        adjust_learning_rate(optimizer, LR, epoch)

        logger.log(' * Epoch %d: %s' % (epoch, LOGDIR))
        logger.log('Epoch %d: %s' % (epoch, optimizer.param_groups[-1]['lr']))

        if epoch == START_EPOCH:
            argmax_confusion_target, confusion_matrix = init_confusion()

            if ARGS.smote:
                logger.log("=============== Applying smote sampling ===============")
                smote_loader, _, _ = get_smote(DATASET, N_SAMPLES_PER_CLASS, BATCH_SIZE, transform_train, transform_test)
                smote_loader_inf = inf_data_gen(smote_loader)
            else:
                logger.log("=============== Applying over sampling ===============")
                train_loader, ori_dataset = get_oversampled(DATASET, N_SAMPLES_PER_CLASS, BATCH_SIZE,
                                                     transform_train, transform_test)

        ## For Cost-Sensitive Learning ##
        if epoch >= ARGS.warm:
            beta = 0.9999
            if beta < 1:
                effective_num = 1.0 - np.power(beta, N_SAMPLES_PER_CLASS)
                per_cls_weights = (1.0 - beta) / np.array(effective_num)
            else:
                per_cls_weights = 1 / np.array(N_SAMPLES_PER_CLASS)

            per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(N_SAMPLES_PER_CLASS)
            per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
        else:
            per_cls_weights = torch.ones(N_CLASSES).to(device)

        ## Choose a loss function ##
        print(ARGS.loss_type)
        print(per_cls_weights)

        if ARGS.loss_type == 'CE':
            criterion = nn.CrossEntropyLoss(weight=per_cls_weights, reduction='none').to(device)
        elif ARGS.loss_type == 'Focal':
            criterion = FocalLoss(weight=per_cls_weights, gamma=ARGS.focal_gamma, reduction='none').to(device)
        elif ARGS.loss_type == 'LDAM':
            criterion = LDAMLoss(cls_num_list=N_SAMPLES_PER_CLASS, max_m=0.5, s=2, weight=per_cls_weights,
                                 reduction='none').to(device)
        elif ARGS.loss_type == 'BS':
            criterion = BalancedSoftmaxLoss(cls_num_list=N_SAMPLES_PER_CLASS, weight=per_cls_weights).to(device)
        elif ARGS.loss_type == 'CB':
            criterion = ClassBalancedSoftmax(cls_num_list=N_SAMPLES_PER_CLASS).to(device)
        else:
            raise ValueError("Wrong Loss Type")

        ## Training ( ARGS.warm is used for deferred re-balancing ) ##

        if epoch >= ARGS.warm and ARGS.gen:
            if ori_data is None:
                ori_data = group_data(ori_dataset)
            train_stats = train_gen_epoch(net, net_seed, criterion, optimizer, train_loader, ori_data)
            SUCCESS[epoch, :, :] = train_stats['t_success'].float()
            logger.log(SUCCESS[epoch, -10:, :])
            np.save(LOGDIR + '/success.npy', SUCCESS.cpu().numpy())
            logger.log("Confusion Target: {}".format(argmax_confusion_target))
        else:
            train_loss, train_acc = train_epoch(net, criterion, optimizer, train_loader, logger)
            train_stats = {'train_loss': train_loss, 'train_acc': train_acc}
            if epoch == 159:
                save_checkpoint(train_acc, net, optimizer, epoch, True)

        ## Evaluation ##
        test_eval, current_confusion_target = evaluate(net, test_loader, logger=logger, save_res=True)

        dynamic_confusion = False
        if dynamic_confusion:
            argmax_confusion_target = current_confusion_target

        test_output = test_eval['acc']

        if test_output >= BEST_VAL:
            BEST_VAL = test_eval['acc']

            TEST_ACC = test_eval['acc']
            TEST_ACC_CLASS = test_eval['class_acc']

            save_checkpoint(TEST_ACC, net, optimizer, epoch)
            logger.log("========== Class-wise test performance ( avg : {} ) ==========".format(TEST_ACC_CLASS.mean()))
            np.save(LOGDIR + '/classwise_acc.npy', TEST_ACC_CLASS.cpu())

        writer.add_scalar('acc/Acc', test_eval['acc'], epoch)
        writer.add_scalar('acc/Major', test_eval['major_acc'], epoch)
        writer.add_scalar('acc/Neutral', test_eval['neutral_acc'], epoch)
        writer.add_scalar('acc/Minor', test_eval['minor_acc'], epoch)
        writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)

        # Test metrics come from this epoch's evaluation (previously read from an
        # always-empty dict, so every test column in the CSV was 0).
        log_vector = [epoch] + [train_stats.get(k, 0) for k in log_tr] + [test_eval.get(k, 0) for k in log_te]
        log_vector = list(map(_convert_scala, log_vector))

        with open(LOG_CSV, 'a', newline='') as f:
            logwriter = csv.writer(f, delimiter=',')
            logwriter.writerow(log_vector)

    logger.log(' * %s' % LOGDIR)
    logger.log("Best Accuracy : {}".format(TEST_ACC))
